package io.github.yezhihao.cipher;

import sun.security.ec.ECPrivateKeyImpl;
import sun.security.ec.ECPublicKeyImpl;
import sun.security.jca.JCAUtil;
import sun.security.util.ECUtil;

import java.math.BigInteger;
import java.security.*;
import java.security.spec.ECGenParameterSpec;
import java.security.spec.ECParameterSpec;
import java.security.spec.ECPoint;

/**
 * ECDSA signing utilities over the secp256k1 curve.
 *
 * <p>Keys are exchanged as hex strings: a public key is the uncompressed point X||Y (64 hex
 * characters each, no prefix byte) and a private key is the scalar. {@link #sign} returns the
 * signature bytes produced by {@link Signature#sign()}; {@link #removeSpecInfo} and
 * {@link #addSpecInfo} convert between that ASN.1 DER form and the fixed-length 64-byte
 * {@code r || s} form.
 *
 * <p>NOTE(review): relies on JDK-internal {@code sun.security} classes and on the installed
 * provider supporting the secp256k1 curve; confirm both are available on the target runtime.
 */
public class ECDSAUtils {

    private static final String SIGNATURE_ALGORITHM = "SHA256withECDSA";

    private static final String KEY_ALGORITHM = "EC";

    private static final ECParameterSpec PARAMETER_SPEC = ECUtil.getECParameterSpec(null, "secp256k1");
    private static final ECGenParameterSpec GEN_PARAMETER_SPEC = new ECGenParameterSpec("secp256k1");

    /** Byte length of one signature component (r or s) in the raw form. */
    private static final int COMPONENT_LENGTH = 32;
    /** Byte length of a raw {@code r || s} signature. */
    private static final int RAW_SIGNATURE_LENGTH = 2 * COMPONENT_LENGTH;
    /** ASN.1 DER tag for SEQUENCE. */
    private static final byte DER_SEQUENCE = 0x30;
    /** ASN.1 DER tag for INTEGER. */
    private static final byte DER_INTEGER = 0x02;

    private static final String INVALID_DER = "签名不是有效的DER格式";

    /**
     * Generates a new secp256k1 key pair using the JDK's default secure random source.
     *
     * @return the generated key pair
     */
    public static KeyPair createKeyPair() {
        return createKeyPair(JCAUtil.getSecureRandom());
    }

    /**
     * Generates a new secp256k1 key pair using the given source of randomness.
     *
     * @param secureRandom randomness used for key generation; when {@code null}, the JDK's
     *                     default secure random source is used instead
     * @return the generated key pair
     * @throws IllegalStateException if the EC algorithm or the secp256k1 curve is unavailable
     */
    public static KeyPair createKeyPair(SecureRandom secureRandom) {
        SecureRandom random = secureRandom != null ? secureRandom : JCAUtil.getSecureRandom();
        try {
            KeyPairGenerator generator = KeyPairGenerator.getInstance(KEY_ALGORITHM);
            // Hand the caller's randomness to the generator so a seeded source is actually honored.
            generator.initialize(GEN_PARAMETER_SPEC, random);
            return generator.generateKeyPair();
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException("生成密钥对失败", e);
        }
    }

    /**
     * Builds a secp256k1 public key from its hex representation.
     *
     * @param publicKey exactly 128 hex characters: the X coordinate (first 64 characters)
     *                  followed by the Y coordinate (last 64 characters)
     * @return the public key
     * @throws InvalidKeyException if the string is {@code null}, not 128 characters long, or
     *                             contains a non-hex character
     */
    public static PublicKey getPublicKey(String publicKey) throws InvalidKeyException {
        if (publicKey == null || publicKey.length() != 128)
            throw new InvalidKeyException("正确的公钥长度应为128");

        BigInteger x = parseHex(publicKey.substring(0, 64));
        BigInteger y = parseHex(publicKey.substring(64, 128));
        return new ECPublicKeyImpl(new ECPoint(x, y), PARAMETER_SPEC);
    }

    /**
     * Builds a secp256k1 private key from its hex representation.
     *
     * <p>The length is not enforced (64 characters is the conventional form), so hex strings
     * with leading zeros stripped are also accepted.
     *
     * @param privateKey the private scalar as a hex string
     * @return the private key
     * @throws InvalidKeyException if the string is {@code null}, empty, or contains a non-hex
     *                             character
     */
    public static PrivateKey getPrivateKey(String privateKey) throws InvalidKeyException {
        if (privateKey == null || privateKey.isEmpty())
            throw new InvalidKeyException("私钥不能为空");

        return new ECPrivateKeyImpl(parseHex(privateKey), PARAMETER_SPEC);
    }

    /**
     * Parses a non-empty unsigned hexadecimal string.
     *
     * <p>{@link BigInteger#BigInteger(String, int)} on its own would accept a leading sign
     * character and report other bad input as an unchecked {@link NumberFormatException}.
     * This method rejects both cases with the checked {@link InvalidKeyException} that the
     * public key-parsing methods declare.
     */
    private static BigInteger parseHex(String hex) throws InvalidKeyException {
        for (int i = 0; i < hex.length(); i++) {
            if (Character.digit(hex.charAt(i), 16) < 0)
                throw new InvalidKeyException("密钥不是合法的十六进制字符串");
        }
        return new BigInteger(hex, 16);
    }

    /**
     * Signs {@code data} with SHA256withECDSA.
     *
     * @param privateKey signing key
     * @param data       data to sign
     * @return the signature bytes returned by {@link Signature#sign()}
     * @throws Exception if the algorithm is unavailable, the key is rejected, or signing fails
     */
    public static byte[] sign(PrivateKey privateKey, byte[] data) throws Exception {
        Signature signature = Signature.getInstance(SIGNATURE_ALGORITHM);
        signature.initSign(privateKey);
        signature.update(data);
        return signature.sign();
    }

    /**
     * Verifies a SHA256withECDSA signature over {@code data}.
     *
     * @param publicKey verification key
     * @param data      signed data
     * @param sign      signature bytes in the form accepted by {@link Signature#verify(byte[])}
     * @return {@code true} if the signature is valid for {@code data}
     * @throws Exception if the algorithm is unavailable, the key is rejected, or the signature
     *                   cannot be decoded
     */
    public static boolean verify(PublicKey publicKey, byte[] data, byte[] sign) throws Exception {
        Signature signature = Signature.getInstance(SIGNATURE_ALGORITHM);
        signature.initVerify(publicKey);
        signature.update(data);
        return signature.verify(sign);
    }

    /**
     * Converts a raw {@code r || s} signature into ASN.1 DER
     * {@code SEQUENCE { INTEGER r, INTEGER s }}.
     *
     * <p>Each 32-byte component is read as an unsigned big-endian integer and encoded
     * minimally. Leading zero bytes are dropped. A single {@code 0x00} is prepended when the
     * high bit of the first remaining byte is set, so the value stays positive.
     *
     * @param sign raw signature; only the first 64 bytes are read
     * @return the DER-encoded signature (at most 72 bytes)
     * @throws IllegalArgumentException if {@code sign} is {@code null} or shorter than 64 bytes
     */
    public static byte[] addSpecInfo(byte[] sign) {
        if (sign == null || sign.length < RAW_SIGNATURE_LENGTH)
            throw new IllegalArgumentException("签名长度不足64字节");

        byte[] r = encodeDerInteger(sign, 0);
        byte[] s = encodeDerInteger(sign, COMPONENT_LENGTH);

        // At most 2 * (tag + length + 33) = 70 bytes, so the short-form length byte always suffices.
        int contentLength = r.length + s.length;
        byte[] result = new byte[2 + contentLength];
        result[0] = DER_SEQUENCE;
        result[1] = (byte) contentLength;
        System.arraycopy(r, 0, result, 2, r.length);
        System.arraycopy(s, 0, result, 2 + r.length, s.length);
        return result;
    }

    /**
     * Encodes the 32-byte unsigned big-endian value at {@code src[offset .. offset+32)} as a
     * minimal DER INTEGER (tag, length, value).
     */
    private static byte[] encodeDerInteger(byte[] src, int offset) {
        int start = offset;
        int end = offset + COMPONENT_LENGTH;
        // DER forbids redundant leading zeros; keep at least one byte so zero encodes as 02 01 00.
        while (start < end - 1 && src[start] == 0)
            start++;
        // A set high bit would be read back as a negative number, so prepend a 0x00 byte.
        int pad = (src[start] & 0x80) != 0 ? 1 : 0;
        int valueLength = end - start;

        byte[] out = new byte[2 + pad + valueLength];
        out[0] = DER_INTEGER;
        out[1] = (byte) (pad + valueLength);
        System.arraycopy(src, start, out, 2 + pad, valueLength);
        return out;
    }

    /**
     * Converts an ASN.1 DER signature {@code SEQUENCE { INTEGER r, INTEGER s }} into the raw
     * 64-byte {@code r || s} form, with each component right-aligned and zero-padded to 32
     * bytes.
     *
     * @param sign DER-encoded signature; bytes after the SEQUENCE are ignored
     * @return the 64-byte raw signature
     * @throws IllegalArgumentException if {@code sign} is not a well-formed DER sequence of two
     *                                  positive INTEGERs, or a component does not fit in 32 bytes
     */
    public static byte[] removeSpecInfo(byte[] sign) {
        if (sign == null || sign.length < 2 || sign[0] != DER_SEQUENCE)
            throw new IllegalArgumentException(INVALID_DER);
        int sequenceLength = sign[1] & 0xFF;
        // Only short-form lengths are valid here: the body of such a signature is at most 70 bytes.
        if (sequenceLength > 0x7F || 2 + sequenceLength > sign.length)
            throw new IllegalArgumentException(INVALID_DER);
        int limit = 2 + sequenceLength;

        byte[] result = new byte[RAW_SIGNATURE_LENGTH];
        int next = decodeDerInteger(sign, 2, limit, result, 0);
        decodeDerInteger(sign, next, limit, result, COMPONENT_LENGTH);
        return result;
    }

    /**
     * Decodes the DER INTEGER that starts at {@code der[offset]} and must end at or before
     * {@code limit}. Its value is written right-aligned into
     * {@code out[outOffset .. outOffset+32)}.
     *
     * @return the index of the first byte after the INTEGER
     */
    private static int decodeDerInteger(byte[] der, int offset, int limit, byte[] out, int outOffset) {
        if (offset + 2 > limit || der[offset] != DER_INTEGER)
            throw new IllegalArgumentException(INVALID_DER);
        int length = der[offset + 1] & 0xFF;
        int start = offset + 2;
        int end = start + length;
        if (length == 0 || length > 0x7F || end > limit)
            throw new IllegalArgumentException(INVALID_DER);
        // r and s are positive; a set high bit without a 0x00 pad would encode a negative number.
        if ((der[start] & 0x80) != 0)
            throw new IllegalArgumentException(INVALID_DER);
        // Skip the sign byte and any leading zeros so the magnitude can be right-aligned.
        while (start < end && der[start] == 0)
            start++;
        int magnitude = end - start;
        if (magnitude > COMPONENT_LENGTH)
            throw new IllegalArgumentException(INVALID_DER);
        System.arraycopy(der, start, out, outOffset + COMPONENT_LENGTH - magnitude, magnitude);
        return end;
    }
}